import re
from collections import defaultdict
from pathlib import Path

from buttercup.common.challenge_task import ChallengeTask
from buttercup.program_model.utils.common import Function, TypeDefinition


class FuzzyCImportsResolver:
    """A resolver for C imports in a source code folder.
    This class can analyze #include statements in files and build dependency trees.

    WARNING: The class resolves imports using very naive file parsing and is best effort, it is not
    sound in nature as it doesn't take into account compile flags and compile commands
    actually run to compile projects.
    """

    def __init__(self, root_dir: Path):
        """Initialize the resolver with a root source code folder.

        Args:
            root_dir: The absolute path to the root folder containing the source code

        """
        # The source code directory to find imports in. This is typically the task src dir
        self.root_dir = root_dir
        # Cache of file imports to avoid re-parsing files
        self.direct_imports_cache: dict[Path, set[Path]] = {}
        self.all_imports_cache: dict[Path, set[Path]] = {}
        # Internal variable used while resolving imports
        self._tmp_imports: set[Path] = set()

    def _normalize_path(self, path: Path | str) -> Path:
        """Normalize a path into an absolute path"""
        # Codequery rebases paths with an initial / which messes
        # with the search paths for files. So here if a file path
        # starts with / we check wether the file exists. If not then
        # it is not an absolute path but rather a relative path
        # from the root_dir and thus we remove the /
        if str(path).startswith("/"):
            if not Path(path).exists():
                path = Path(str(path)[1:])
        # Convert to Path if a string object was supplied
        path = Path(path)
        # If path not absolute, rebase from root_dir
        if not path.is_absolute():
            path = (self.root_dir / path).resolve()
        return path

    def _find_file_in_codebase(self, import_name: str, origin_file: Path) -> Path | None:
        """Find the actual file path for an imported file name.
        This method handles different include styles:
        - System includes like <stdio.h> (ignored as external)
        - Relative includes like "utils/helper.h"
        - Project includes

        Args:
            import_name: The name of the import as it appears in the #include statement
            origin_file: The file that contains the import statement

        Returns:
            The absolute path to the imported file or None if not found

        """
        # Ignore system includes (enclosed in <>)
        if import_name.startswith("<") and import_name.endswith(">"):
            return None

        # Remove quotes if present
        if import_name.startswith('"') and import_name.endswith('"'):
            import_name = import_name[1:-1]

        # Try relative to the origin file's directory
        origin_dir = origin_file.parent
        candidate = (origin_dir / import_name).resolve()
        if candidate.is_file():
            return candidate

        return None

    def get_direct_imports(self, file_path: Path) -> set[Path]:
        """Parse a file for import statements and try to resolve them
        to actual files present in the codebase.

        Args:
            file_path: The path to the file to parse

        Returns:
            A list of paths to the imported files if they have been successfully
            found in the code directory

        """
        file_path = self._normalize_path(file_path)

        # Check if we already have this file's imports in cache
        if file_path in self.direct_imports_cache:
            return self.direct_imports_cache[file_path]
        try:
            content = file_path.read_text()

            # Find all #include statements using regex
            # This pattern matches both #include <file> and #include "file"
            include_pattern = re.compile(r'#include\s+([<"]([^>"]+)[>"])')
            matches = include_pattern.findall(content)

            # Extract the import names and try to resolve their paths
            imports = set()
            for _, import_name in matches:
                # NOTE(boyan): import_name is the import as present in the source code,
                # for example "somefile.h" from '#include "somefile.h"'. Sometimes the file
                # "somefile.h" doesn't exist but is templated and generated by the build toolchain.
                # Here we add other typical variants of the file name that correspond to
                # typical patterns found in repositories
                # NOTE(boyan): the candidate files below are checked in order and we select the
                # first match.
                # TODO(boyan): see wether there are other common extensions we should add
                import_files_candidates = [import_name, import_name + ".in"]

                for file_candidate in import_files_candidates:
                    resolved_path = self._find_file_in_codebase(file_candidate, file_path)
                    if resolved_path:
                        imports.add(resolved_path)
                        break

            self.direct_imports_cache[file_path] = imports
            return imports

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
            return set()

    def get_all_imports(self, file_path: Path, depth: int | None = None) -> set[Path]:
        """Recursively get all files imported by a given file.

        Args:
            file: The path to the file (relative to root or absolute)
            depth: Maximum depth to traverse, None for unlimited

        Returns:
            A list of absolute paths to all imported files

        """
        file_path = self._normalize_path(file_path)
        if file_path in self.all_imports_cache:
            return self.all_imports_cache[file_path]
        # Reset processing set for new recursive call
        self._tmp_imports = set()
        self._get_all_imports_recursive(file_path, depth)
        self.all_imports_cache[file_path] = self._tmp_imports
        return self._tmp_imports

    def _get_all_imports_recursive(self, file_path: Path, depth: int | None, current_depth: int = 0) -> None:
        """Recursive helper for get_all_imports.

        Args:
            file_path: The absolute path to the file
            depth: Maximum depth to traverse, None for unlimited
            current_depth: Current recursion depth

        """
        # Check if we've reached the maximum depth
        if depth is not None and current_depth >= depth:
            return

        # Detect import loops
        if file_path in self._tmp_imports:
            return  # Skip this file to break the loop

        # Mark this file as being done
        self._tmp_imports.add(file_path)
        immediate_imports = self.get_direct_imports(file_path)

        # Recursively get imports of imports
        for import_file in immediate_imports:
            # Add nested imports. This should return only new imports
            self._get_all_imports_recursive(import_file, depth, current_depth + 1)
            # Record actual import of current file
            self._tmp_imports.add(import_file)

    def is_file_imported_by(self, imported_file_path: Path, file_path: Path) -> bool:
        """Return True if imported_file_path is imported by file_path (either directly or indirectly through
        nested imports)
        """
        all_imports = self.get_all_imports(self._normalize_path(file_path))
        return self._normalize_path(imported_file_path) in all_imports

    def filter_callees(self, caller_function: Function, callees: list[Function]) -> list[Function]:
        """This filters callees found for a given function. This function is used to deduplicate
        callees when several are found with identical names. This happens because Codequery does a
        syntactic search only and can't resolve which function is actually imported and called
        by the caller.

        This function is best effort, it takes in a list of callees and returns a list of callees.
        """
        # Group callees by function name. Duplicate callees with the same
        # name end-up in the same group
        callee_groups = defaultdict(list)
        for callee in callees:
            callee_groups[callee.name].append(callee)
        # Filter callees for each callee name
        res = []
        for group in callee_groups.values():
            if len(group) <= 1:
                # Only one callee with this name, no dedup needed
                res += group
            else:
                added_at_least_one = False
                # Multiple callees with same name, check which ones are actually imported
                for callee in group:
                    added = False
                    # NOTE(boyan): the extensions in the code for decl files below is
                    # specific to C. CHange it depending on programming language.

                    # Here we take the file in whih the potential callee is defined
                    # (usually a .c file) and try to guess in which header files it
                    # be declared
                    callee_file = callee.file_path
                    callee_file_h = str(callee.file_path).replace(".c", ".h")
                    possible_decl_files = [
                        callee_file,
                        Path(callee_file_h),
                        Path(callee_file_h + ".in"),
                        Path(callee_file_h + "h"),
                    ]
                    # For each potential declaration file, see whether it is imported by the caller
                    for decl_file in possible_decl_files:
                        if self.is_file_imported_by(decl_file, caller_function.file_path):
                            res.append(callee)
                            added = True
                            break

                    # If we couldn't resolve the import for this callee, see if its potential
                    # declaration file name is present in caller imports after discarding paths
                    # I.e if funcA is defined in A.h and the caller func imports src/foo/A.h, still
                    # consider this a match.
                    # This loose approach prevents from aggressively filtering out imported files
                    # in projects where import dirs are managed with compilation flags
                    if not added:
                        for decl_file in possible_decl_files:
                            if any(
                                x
                                for x in self.get_all_imports(caller_function.file_path)
                                if str(x).endswith(f"/{Path(decl_file).name}")
                            ):
                                res.append(callee)
                                added = True
                                break

                    added_at_least_one |= added

                # If we couldn't find even one callee that is correctly imported
                # by the caller, just return all of them as codequery originally
                # found them. This scenario means that we have hit a shortcoming of
                # the fuzzy import resolver. It is then better to add all found callees
                # and assume the correct one is present, rather than none of them.
                # At the end of the day we want to give the model some material to work with
                if not added_at_least_one:
                    res += group
        return res


class FuzzyJavaImportsResolver:
    """A resolver for Java imports in a source code folder.
    This class can analyze import statements in files and lookup class defs to
    find out the actual methods that are called in a given caller file or function.
    It is used to deduplicate code search results for called functions and callee
    search.

    WARNING: The class has some limitations and resolves methods in a best effort manner.
    If it can't filter provided methods then it returns the same unfiltered function list
    that it was given as input
    """

    def __init__(self, challenge: ChallengeTask, codequery: "CodeQuery"):  # type: ignore # noqa: F821
        # TODO(boyan): make sure these paths hold for the competition
        # Path where the challenge source is mounted in the ossfuzz repo
        # according to docker file
        if challenge:
            self.container_code_path = Path(str(challenge.workdir_from_dockerfile())[1:])
            # Path where the challenge source is on the local machine
            self.local_code_path = challenge.task_dir / "container_src_dir" / "src" / challenge.focus
        self.codequery = codequery

    def get_package_from_file(self, file_path: Path) -> str | None:
        """Get the package name from a file path"""
        # Parse lines to find one that starts with "package"
        with open(self._normalize_path(file_path)) as f:
            for line in f:
                if line.startswith("package"):
                    return line.split(" ")[1].strip()
        return None

    def get_dotexpr_type(self, dotexpr: str, file_path: Path) -> TypeDefinition | None:
        """Get the type of a dot expression."""
        # If the dotexpr has no dots, see if it is an imported class
        if "." not in dotexpr:
            # Get imports from the caller file
            imports = self.parse_imports_in_file(self._normalize_path(file_path))
            file_package = self.get_package_from_file(file_path)
            if file_package is None:
                return None

            # Keep only imports that end with the expression
            imports = [imp for imp in imports if imp.endswith(f".{dotexpr}")]
            if imports:
                imp = imports[0]
            else:
                # No imports, maybe the type is in the same package and
                # thus not explicitly imported.
                # TODO(boyan): I think the proper way to do this would be
                # to add all the files in the package instead of guessing
                # the name of the file based on the class name
                imp = file_package + "." + dotexpr

            # Get path of file from where the import is made
            # TODO(boyan): make sure we can assume files end with .java here
            # First transform import statement to corresponding file in the code base
            imported_file = "../" * (file_package.count(".") + 1) + imp.replace(".", "/") + ".java"
            imported_file = (file_path.parent / imported_file).resolve()
            # Then try to get type from that file and return it
            return self.get_type_from_file(imported_file, dotexpr)

        # If the dotexpr has dots, iteratively get the type of the prefix
        prefix, suffix, expr_type = self.split_rightmost_dotexpr(dotexpr)
        prefix_type = self.get_dotexpr_type(prefix, file_path)
        if prefix_type is None:
            return None
        # Parse the type definition to find the field/method type
        if expr_type == "field":
            field_type_name = self.get_field_type_name(prefix_type, suffix)
            if field_type_name is None:
                return None
            # Then get the actual type definition. Do do this we resolve the
            # type of the field within the file where the prefix type is defined.
            # E.G for Foo.a if we now type name of field a is Bar then we
            # look for the type Bar that is imported in /path/to/Foo.java
            res = self.get_dotexpr_type(field_type_name, prefix_type.file_path)
            return res
        if expr_type == "method":
            # TODO(boyan): resolve class methods, here we assume it's a method
            method_return_type_name = self.get_method_return_type_name(prefix_type, suffix)
            if method_return_type_name is None:
                return None
            res = self.get_dotexpr_type(method_return_type_name, prefix_type.file_path)
            return res
        # Should not happen
        return None

    def split_rightmost_dotexpr(self, expr: str) -> tuple[str, str, str | None]:
        """Splits a dot expression into two parts:
        1. The rest of the left side of the expression unmodified
        2. The rightmost toplevel field or method name after the last dot

        Args:
            expr: A string like "a.b(c).d(e())"

        Returns:
            A tuple of (left_part, right_part, type), e.g. ("a.b(c)", "d", <type>)
            where type is either "field", "method", or None

        """
        if not expr:
            return "", "", None

        # Scan the expression from right to left, tracking parentheses levels
        paren_level = 0
        last_dot_index = -1
        for i in range(len(expr) - 1, -1, -1):
            char = expr[i]
            if char == ")":
                paren_level += 1
            elif char == "(":
                paren_level -= 1
            elif char == "." and paren_level == 0:
                # We found the last top-level dot
                last_dot_index = i
                break

        # The left part is everything before the last dot
        left_part = expr[:last_dot_index] if last_dot_index != -1 else ""

        # For the right part, we need to extract just the identifier (until a parenthesis or end)
        right_start = last_dot_index + 1
        right_end = right_start
        while right_end < len(expr) and expr[right_end] != "(":
            right_end += 1
        t = "method" if right_end < len(expr) and expr[right_end] == "(" else "field"
        right_part = expr[right_start:right_end]
        return left_part, right_part, t

    def get_field_type_name(self, t: TypeDefinition, field_name: str) -> str | None:
        """Parse the type definition to find the field type name"""
        type_body = t.definition.encode("utf-8")
        type_name = self.codequery.ts.get_field_type_name(type_body, field_name)
        return type_name  # type: ignore[no-any-return]

    def get_method_return_type_name(self, t: TypeDefinition, method_name: str) -> str | None:
        """Parse the type definition to find the method return type name"""
        type_body = t.definition.encode("utf-8")
        type_name = self.codequery.ts.get_method_return_type_name(type_body, method_name)
        return type_name  # type: ignore[no-any-return]

    def get_type_from_file(self, file_path: Path, type_name: str) -> TypeDefinition | None:
        """Get the type definition given a type name and a file path
        file_path must a container path (e.g. /src/log4j-core/...)
        """
        types = self.codequery.get_types(
            type_name,
            file_path,
        )
        if not types:
            return None
        # Return first type found
        return types[0]  # type: ignore[no-any-return]

    def filter_callees(self, caller_function: Function, callees: list[Function]) -> list[Function]:
        callee_groups = defaultdict(list)
        for callee in callees:
            callee_groups[callee.name].append(callee)
        res = []
        # Filter all callees with the same name
        for callee_name, group in callee_groups.items():
            if len(group) <= 1:
                # Only one callee with this name, no dedup needed
                res += group
            else:
                added_at_least_one = False
                # Multiple callees with same name, check which ones are the real ones
                # Do only once per name...

                # Get call "prefixes". If a method is called with a.b.c.d() the prefix
                # is a.b.c. We use this to determine which class or file the d() method
                # belongs to
                prefixes = self.try_extract_call_expr_prefix(caller_function, callee_name)
                if not prefixes:
                    continue

                # Get all types for prefixes found
                prefixes_types = [self.get_dotexpr_type(prefix, caller_function.file_path) for prefix in prefixes]
                # Filter out None types (couldn't find the type)
                prefixes_types = [t for t in prefixes_types if t is not None]
                if not prefixes_types:
                    continue
                # Get all files where there is a definition of the callee within
                # a type that matches a call prefix inside the caller.
                prefixes_files = [t.file_path for t in prefixes_types]
                # Get all callees that are defined in any of the prefix files
                for callee in group:
                    if callee.file_path in prefixes_files:
                        res.append(callee)
                        added_at_least_one = True

                # If we couldn't find even one callee that is correctly imported
                # by the caller, just return all of them as codequery originally
                # found them. This scenario means that we have hit a shortcoming of
                # the fuzzy import resolver. It is then better to add all found callees
                # and assume the correct one is present, rather than none of them.
                # At the end of the day we want to give the model some material to work with
                if not added_at_least_one:
                    res += group

        return res

    def filter_callees2(self, caller_function: Function, callees: list[Function]) -> list[Function]:
        callee_groups = defaultdict(list)
        for callee in callees:
            callee_groups[callee.name].append(callee)
        res = []
        # Filter all callees with the same name
        for callee_name, group in callee_groups.items():
            if len(group) <= 1:
                # Only one callee with this name, no dedup needed
                res += group
            else:
                added_at_least_one = False
                # Multiple callees with same name, check which ones are actually imported
                # Do only once per name...
                # Get imports from the caller file
                imports = self.parse_imports_in_file(self._normalize_path(caller_function.file_path))
                # Get call "prefixes". If a method is called with a.b.c.d() the prefix
                # is a.b.c. We use this to determine which class or file the d() method
                # belongs to
                prefixes = self.try_extract_call_expr_prefix(caller_function, callee_name)
                if not prefixes or not imports:
                    continue
                # Keep only imports that match prefixes for that function name
                # This means that given the import "org.foo.bar.Stuff"; we keep it
                # only if there is at least one callee that is called with "Stuff.<callee_name>(...)".

                # TODO(boyan): need to refactor this when we support recursively
                # exploring the prefixes with multiple dots
                imports = [imp for imp in imports if any(pref for pref in prefixes if imp.endswith(f".{pref}"))]

                # At this point we have only imports that match with the prefix of a called
                # function with name callee_name in the caller body, we now proceed to add any callee
                # that comes from the imported file
                for imp in imports:
                    # TODO(boyan): make sure we can assume files end with .java here
                    # First transform import statement to corresponding file in the code base
                    imported_file = "../" * (imp.count(".") + 1) + imp.replace(".", "/") + ".java"
                    imported_file = self._normalize_path(caller_function.file_path.parent) / imported_file
                    imported_file = imported_file.resolve()
                    for callee in group:
                        added = False
                        # See if this is an import
                        callee_file = self._normalize_path(callee.file_path)
                        # The callee is defined in a file that matches its prefix
                        # and is imported in the caller file, so it's a real callee
                        # and we keep it
                        if imported_file == callee_file:
                            res.append(callee)
                            added = True
                        added_at_least_one |= added

                # If we couldn't find even one callee that is correctly imported
                # by the caller, just return all of them as codequery originally
                # found them. This scenario means that we have hit a shortcoming of
                # the fuzzy import resolver. It is then better to add all found callees
                # and assume the correct one is present, rather than none of them.
                # At the end of the day we want to give the model some material to work with
                if not added_at_least_one:
                    res += group
        return res

    def try_extract_call_expr_prefix(self, caller: Function, callee_name: str) -> list[str]:
        """Try to extract all call prefixes of a function called in the caller body.
        If the caller is:
        public void foo() {
            a.a.b(c.d.e());
            foo.b(3);
        }
        and the callee_name is "b" then the result is ["a.a", "foo"]
        """
        # TODO(boyan): handle the case where function is called directly
        # without a leading '.'
        call_marker = f".{callee_name}("
        caller_body = caller.bodies[0].body
        res = []
        for expr in caller_body.split():
            if call_marker in expr:
                prefix = expr.split(call_marker)[0]
                # Handle the case where the call is within another call like
                # a.b(c.d.e())
                if "(" in prefix:
                    prefix = prefix.split("(")[-1]
                res.append(prefix)
        return res

    def parse_imports_in_file(self, file_path: Path) -> list[str]:
        """Parse import statements in file and return the list of
        imported strings. E.g for 'import a.b.c;', the function returns
        ['a.b.c']
        """
        import_pattern = r"import\s+([\w.]+);"
        res = []
        with open(file_path) as f:
            for line in f:
                # Find all matches in the code
                matches = re.findall(import_pattern, line)
                res += matches
        return res

    def _normalize_path(self, path: Path | str) -> Path:
        """Normalize a path into an absolute path"""
        # Codequery rebases paths with an initial / which messes
        # with the search paths for files. So here if a file path
        # starts with / we check wether the file exists. If not then
        # it is not an absolute path but rather a relative path
        # from the root_dir and thus we remove the /
        if str(path).startswith("/"):
            if not Path(path).exists():
                path = Path(str(path)[1:])
        # Convert to Path if a string object was supplied
        path = Path(path)
        # Transform container-paths into local paths
        if str(path).startswith(str(self.container_code_path)):
            path = Path(str(path)[len(str(self.container_code_path)) + 1 :])
        # If path not absolute, rebase from local source dir
        if not path.is_absolute():
            path = (self.local_code_path / path).resolve()
        return path

    def _relative_path(self, path: Path) -> Path:
        """Convert a path to a relative path"""
        p = str(path)
        # Remove local source path
        if p.startswith(str(self.local_code_path)):
            p = p[len(str(self.local_code_path)) + 2 :]
        # And add container source path
        if not p.startswith(str(self.container_code_path)):
            p = str(self.container_code_path) + "/" + p
        return Path(p)
